package server

import (
	"encoding/binary"
	"fmt"

	dmysql "github.com/go-sql-driver/mysql"
	"github.com/pingcap/tidb/mysql"

	"github.com/zeast/logs"
)

// Tx is the transaction interface: a unit of work that is ended either by
// Commit or by Rollback.
type Tx interface {
	// Commit ends the transaction, keeping its changes.
	Commit() error
	// Rollback ends the transaction, discarding its changes.
	Rollback() error
}

// mysqlTx is the Tx implementation backed by a single MySQL connection.
// The connection is embedded, so its methods are promoted onto the
// transaction value. The Commit and Rollback methods in this file set the
// embedded pointer to nil once they have sent their statement.
type mysqlTx struct {
	*mysqlConn
}

// Compile-time check that *mysqlTx satisfies Tx.
var _ Tx = (*mysqlTx)(nil)

// Close ends the transaction by delegating to Rollback, so it returns
// exactly what Rollback returns.
func (tx *mysqlTx) Close() error {
	err := tx.Rollback()
	return err
}

// Exec sends data over the transaction's connection via mysqlConn.Exec and
// returns the result bytes it produces.
//
// If the transaction has no usable connection, Exec returns mysql.ErrBadConn.
// That covers two cases: the embedded connection is nil (Commit and Rollback
// clear it) or its network connection is gone. On any error from
// mysqlConn.Exec, the result is discarded and nil is returned.
func (tx *mysqlTx) Exec(data []byte) ([]byte, error) {
	// Check the embedded pointer itself first. After Commit/Rollback it is
	// nil, and reading .netConn through it would panic.
	if tx.mysqlConn == nil || tx.mysqlConn.netConn == nil {
		return nil, mysql.ErrBadConn
	}
	// Send command
	result, err := tx.mysqlConn.Exec(data)
	if err != nil {
		return nil, err
	}
	return result, nil
}

// Query sends data over the transaction's connection via mysqlConn.Query
// and returns the rows it produces as raw byte slices.
//
// If the transaction has no usable connection, Query returns
// mysql.ErrBadConn. That covers two cases: the embedded connection is nil
// (Commit and Rollback clear it) or its network connection is gone. On any
// error from mysqlConn.Query, the result is discarded and nil is returned.
func (tx *mysqlTx) Query(data []byte) ([][]byte, error) {
	// Check the embedded pointer itself first. After Commit/Rollback it is
	// nil, and reading .netConn through it would panic.
	if tx.mysqlConn == nil || tx.mysqlConn.netConn == nil {
		return nil, mysql.ErrBadConn
	}
	// Send command
	result, err := tx.mysqlConn.Query(data)
	if err != nil {
		return nil, err
	}
	return result, nil
}

// Commit finishes the transaction by writing a COMMIT statement as a
// COM_QUERY command packet. Afterwards the embedded connection is cleared
// whether or not the write succeeded. A second Commit or Rollback on the
// same value therefore returns dmysql.ErrInvalidConn.
//
// NOTE(review): this method only writes the command; it does not read the
// server's reply. Confirm the reply is consumed elsewhere.
func (tx *mysqlTx) Commit() (err error) {
	conn := tx.mysqlConn
	if conn == nil || conn.netConn == nil {
		return dmysql.ErrInvalidConn
	}
	err = conn.writeCommandPacketStr(mysql.ComQuery, "COMMIT")
	tx.mysqlConn = nil
	return err
}

// Rollback abandons the transaction by writing a ROLLBACK statement as a
// COM_QUERY command packet. Afterwards the embedded connection is cleared
// whether or not the write succeeded. A second Commit or Rollback on the
// same value therefore returns dmysql.ErrInvalidConn.
//
// NOTE(review): this method only writes the command; it does not read the
// server's reply. Confirm the reply is consumed elsewhere.
func (tx *mysqlTx) Rollback() (err error) {
	conn := tx.mysqlConn
	if conn == nil || conn.netConn == nil {
		return dmysql.ErrInvalidConn
	}
	logs.Debug("ROLLBACK")
	err = conn.writeCommandPacketStr(mysql.ComQuery, "ROLLBACK")
	tx.mysqlConn = nil
	return err
}

// mysqlStmt is a prepared statement tied to the connection it was prepared
// on.
type mysqlStmt struct {
	// id is the statement id read from the server's prepare response.
	id uint32
	// The connection the statement belongs to; its methods are promoted.
	*mysqlConn
	// paramCount is the parameter count read from the prepare response.
	paramCount int
	// NOTE: column metadata is not kept; a `columns []mysqlField` field
	// was left commented out here.
}

// Close releases the prepared statement by writing a COM_STMT_CLOSE command
// that carries the statement id. The write is done while holding the lock
// reached through ms (presumably promoted from the embedded *mysqlConn;
// TODO confirm).
//
// If the statement has no connection, or the connection's network
// connection is nil, Close logs the condition and returns mysql.ErrBadConn.
func (ms *mysqlStmt) Close() error {
	// Check for a nil connection before locking. mysqlStmt has no mutex of
	// its own, so Lock goes through the embedded pointer and would panic if
	// it were nil.
	if ms.mysqlConn == nil {
		logs.Debug("netConn had closed.")
		return mysql.ErrBadConn
	}

	ms.Lock()
	defer ms.Unlock()

	if ms.mysqlConn.netConn == nil {
		logs.Debug("netConn had closed.")
		return mysql.ErrBadConn
	}
	return ms.mysqlConn.writeCommandPacketUint32(mysql.ComStmtClose, ms.id)
}

// NumInput returns the number of parameters the statement takes. The value
// comes from the paramCount field, which readPrepareResultPacket populates.
func (ms *mysqlStmt) NumInput() int {
	n := ms.paramCount
	return n
}

// Prepare Result Packets
// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
//
// readPrepareResultPacket reads one packet from ms.io and, if it is an OK
// prepare response, stores the statement id in ms.id and the parameter
// count in ms.paramCount. It expects this layout:
//
//	[0]      status byte, mysql.OKHeader on success
//	[1:5]    statement id, uint32 little-endian
//	[5:7]    column count, uint16 little-endian
//	[7:9]    parameter count, uint16 little-endian
//	[9]      reserved
//	[10:12]  warning count, uint16 little-endian (checked only if present)
//
// It returns the raw packet, the column count, and an error. The error is
// one of the following:
//
//	the read error from ms.io.readPacket
//	the result of handleErrorPacket, when the status byte is not OK
//	a "malformed prepare response" error, when the packet is empty or
//	  shorter than the 9-byte fixed header
//	a warnings error, when ms.strict is set and the server reports
//	  warnings (the column count is still returned)
func (ms *mysqlStmt) readPrepareResultPacket() ([]byte, uint16, error) {
	data, err := ms.io.readPacket()
	if err != nil {
		return data, 0, err
	}

	// An empty packet has no status byte to inspect.
	if len(data) == 0 {
		return data, 0, fmt.Errorf("malformed prepare response: empty packet")
	}

	// packet indicator [1 byte]
	if data[0] != mysql.OKHeader {
		return data, 0, ms.handleErrorPacket(data)
	}

	// The status byte, statement id and both counts must be present before
	// they are sliced out below. Otherwise a truncated packet would panic.
	const minPrepareOKLen = 9
	if len(data) < minPrepareOKLen {
		return data, 0, fmt.Errorf("malformed prepare response: %d bytes, want at least %d",
			len(data), minPrepareOKLen)
	}

	// statement id [4 bytes]
	ms.id = binary.LittleEndian.Uint32(data[1:5])

	// Column count [16 bit uint]
	columnCount := binary.LittleEndian.Uint16(data[5:7])

	// Param count [16 bit uint]
	ms.paramCount = int(binary.LittleEndian.Uint16(data[7:9]))

	// data[9] is reserved. The warning count is only checked in strict mode.
	if !ms.strict {
		return data, columnCount, nil
	}

	// Check for warnings count > 0, only available in MySQL > 4.1
	if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 {
		// NOTE(review): this message looks like a placeholder for a
		// getWarnings call that was never implemented here.
		return data, columnCount, fmt.Errorf("mc.getWarnings()")
	}
	return data, columnCount, nil
}
